---
title: Time Alignment with micro-tcn
keywords: fastai
sidebar: home_sidebar
nb_path: "02_time_align.ipynb"
---

Work in progress for NASH Hackathon, Dec 17, 2021

This is like the 01_td_demo notebook, except that we use a different dataset and generalize the dataloader a bit.

Installs and imports

{% raw %}
# Next line only executes on Colab. Colab users: Please enable GPU in Edit > Notebook settings
! [ -e /content ] && pip install -Uqq pip fastai git+https://github.com/drscotthawley/fastproaudio.git

# Additional installs for this tutorial
%pip install -q fastai_minima torchsummary pyzenodo3 wandb

# Install micro-tcn and auraloss packages (from source, will take a little while)
%pip install -q wheel --ignore-requires-python git+https://github.com/csteinmetz1/micro-tcn.git  git+https://github.com/csteinmetz1/auraloss

# After this cell finishes, restart the kernel and continue below
WARNING: You are using pip version 21.3; however, version 21.3.1 is available.
You should consider upgrading via the '/home/shawley/envs/fastai/bin/python -m pip install --upgrade pip' command.
Note: you may need to restart the kernel to use updated packages.
  WARNING: Missing build requirements in pyproject.toml for git+https://github.com/csteinmetz1/auraloss.
  WARNING: The project does not specify a build backend, and pip cannot fall back to setuptools without 'wheel'.
{% endraw %} {% raw %}
from fastai.vision.all import *
from fastai.text.all import *
from fastai.callback.fp16 import *
import wandb
from fastai.callback.wandb import *
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from IPython.display import Audio 
import matplotlib.pyplot as plt
import torchsummary
from fastproaudio.core import *
from pathlib import Path
from glob import glob
import json
import re 
{% endraw %}

Check out the Data

Jacob is working on generating the dataset(s). He'll probably put it in a private Dropbox.

{% raw %}
path = Path('wherever jacob puts the data')


fnames_in = sorted(glob(str(path)+'/*/input*'))
fnames_targ = sorted(glob(str(path)+'/*/*targ*'))
ind = -1   # pick one spot in the list of files
fnames_in[ind], fnames_targ[ind]
('/home/shawley/.fastai/data/SignalTrain_LA2A_Reduced/Val/input_260_.wav',
 '/home/shawley/.fastai/data/SignalTrain_LA2A_Reduced/Val/target_260_LA2A_2c__1__85.wav')
{% endraw %}

Input audio

{% raw %}
waveform, sample_rate = torchaudio.load(fnames_in[ind])
show_audio(waveform, sample_rate)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.225,  Min: -0.218, Mean:  0.000, Std Dev:  0.038
{% endraw %}

Target output audio

{% raw %}
target, sr_targ = torchaudio.load(fnames_targ[ind])
show_audio(target, sr_targ)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.091,  Min: -0.103, Mean: -0.000, Std Dev:  0.021
{% endraw %}

Difference

Let's look at the difference between the target audio and the input audio.

{% raw %}
show_audio(target - waveform, sample_rate)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.144,  Min: -0.159, Mean: -0.000, Std Dev:  0.018
{% endraw %} {% raw %}
def get_accompanying_tracks(fn, fn_list, remove=False):
    """Given one filename and a list of all filenames, return a list of that filename and
    any files it 'goes with'.
    remove: remove these accompanying files from the main list.
    """
    # make copies of fn & fn_list with all hyphen+suffix parts removed.
    basename = re.sub(r'-[a-zA-Z0-9]+','', fn) 
    basename_list = [re.sub(r'-[a-zA-Z0-9]+','', x) for x in fn_list]
    
    # get indices of all elements of basename_list matching basename, return original filenames
    accompanying = [fn_list[i] for i, x in enumerate(basename_list) if x == basename]
    if remove: 
        for x in accompanying: 
            if x != fn: fn_list.remove(x)  # don't remove the file we search on though
    return accompanying # note accompanying list includes original file too
{% endraw %} {% raw %}
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
print(fn_list)
track = fn_list[1]
print("getting matching tracks for ",track)
tracks  = get_accompanying_tracks(fn_list[1], fn_list, remove=True)
print("Accompanying tracks are: ",tracks)
print("new list = ",fn_list) # should have the extra 21- tracks removed.
['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
getting matching tracks for  input_21-1_.wav
Accompanying tracks are:  ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav']
new list =  ['input_21-1_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
{% endraw %} {% raw %}
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
fn_list_save = fn_list.copy() 
for x in fn_list:
    get_accompanying_tracks(x, fn_list, remove=True)
fn_list, fn_list_save
(['input_21-0_.wav', 'input_22_.wav', 'input_23_.wav', 'input_24-0_.wav'],
 ['input_21-0_.wav',
  'input_21-1_.wav',
  'input_21-hey_.wav',
  'input_22_.wav',
  'input_23_.wav',
  'input_23-toms_.wav',
  'input_24-0_.wav',
  'input_24-kick_.wav'])
{% endraw %}
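The hyphen-stripping pattern is the heart of this grouping. Here is a tiny standalone check of what it matches (plain `re`, nothing else assumed):

{% raw %}
```python
import re

# The pattern used by get_accompanying_tracks: a hyphen followed by
# alphanumerics is removed, so hyphen-designated variants share a base name.
pattern = r'-[a-zA-Z0-9]+'
names = ['input_21-0_.wav', 'input_21-hey_.wav', 'input_22_.wav']
bases = [re.sub(pattern, '', n) for n in names]
print(bases)  # ['input_21_.wav', 'input_21_.wav', 'input_22_.wav']
```
{% endraw %}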

Dataset class and Dataloaders

Here we modify Christian Steinmetz's SignalTrainLA2ADataset class.

{% raw %}
class FastProAudioDataset(torch.utils.data.Dataset):
    """ Modifying Steinmetz' micro-tcn code so we can load the kind of multichannel audio we want.
    The difference is that now, we group files that are similar except for a hyphen-designation;
    e.g., input_235-1_.wav and input_235-2_.wav get read into one tensor.
    
    The 'trick' will be that we only ever store one filename 'version' of a group of files, but whenever we 
    want to try to load that file, we will also grab all its associated files. 
    
    Like SignalTrain LA2A dataset only more general"""
    def __init__(self, root_dir, subset="train", length=16384, preload=False, half=True, fraction=1.0, use_soundfile=False):
        """
        Args:
            root_dir (str): Path to the root directory of the SignalTrain dataset.
            subset (str, optional): Pull data either from "train", "val", "test", or "full" subsets. (Default: "train")
            length (int, optional): Number of samples in the returned examples. (Default: 16384)
            preload (bool, optional): Read in all data into RAM during init. (Default: False)
            half (bool, optional): Store the float32 audio as float16. (Default: True)
            fraction (float, optional): Fraction of the data to load from the subset. (Default: 1.0)
            use_soundfile (bool, optional): Use the soundfile library to load instead of torchaudio. (Default: False)
        """
        self.root_dir = root_dir
        self.subset = subset
        self.length = length
        self.preload = preload
        self.half = half
        self.fraction = fraction
        self.use_soundfile = use_soundfile

        if self.subset == "full":
            self.target_files = glob(os.path.join(self.root_dir, "**", "target_*.wav"))
            self.input_files  = glob(os.path.join(self.root_dir, "**", "input_*.wav"))
        else:
            # get all the target files in the directory first
            # (note: we imported `from glob import glob` above, so call glob() directly)
            self.target_files = glob(os.path.join(self.root_dir, self.subset.capitalize(), "target_*.wav"))
            self.input_files  = glob(os.path.join(self.root_dir, self.subset.capitalize(), "input_*.wav"))

        self.examples = [] 
        self.minutes = 0  # total number of minutes in the subset

        # ensure that the sets are ordered correctly
        self.target_files.sort()
        self.input_files.sort()

        # get the parameters 
        self.params = [(float(f.split("__")[1].replace(".wav","")), float(f.split("__")[2].replace(".wav",""))) for f in self.target_files]

        
        # SHH: HERE is where we'll package similar hyphen-designated files together. list comprehension here wouldn't be good btw.
        # essentially we are removing 'duplicates'. the first file of each group will be the signifier of all of them
        self.target_files_save, self.input_files_save = self.target_files.copy(), self.input_files.copy() # save a copy of original list
        for x in self.target_files:
            get_accompanying_tracks(x, self.target_files, remove=True)
        for x in self.input_files:
            get_accompanying_tracks(x, self.input_files, remove=True)
        
        # loop over files to count total length
        for idx, (tfile, ifile, params) in enumerate(zip(self.target_files, self.input_files, self.params)):

            ifile_id = int(os.path.basename(ifile).split("_")[1])
            tfile_id = int(os.path.basename(tfile).split("_")[1])
            if ifile_id != tfile_id:
                raise RuntimeError(f"Found non-matching file ids: {ifile_id} != {tfile_id}! Check dataset.")

            md = torchaudio.info(tfile)
            num_frames = md.num_frames

            if self.preload:
                sys.stdout.write(f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...\r")
                sys.stdout.flush()
                
                input, sr  = self.load_accompanying(ifile, self.input_files_save)
                target, sr = self.load_accompanying(tfile, self.target_files_save)

                num_frames = int(np.min([input.shape[-1], target.shape[-1]]))
                if input.shape[-1] != target.shape[-1]:
                    print(os.path.basename(ifile), input.shape[-1], os.path.basename(tfile), target.shape[-1])
                    raise RuntimeError("Found potentially corrupt file!")
                if self.half:
                    input = input.half()
                    target = target.half()
            else:
                input = None
                target = None

            # create one entry for each patch
            self.file_examples = []
            for n in range((num_frames // self.length)):
                offset = int(n * self.length)
                end = offset + self.length
                self.file_examples.append({"idx": idx, 
                                           "target_file" : tfile,
                                           "input_file" : ifile,
                                           "input_audio" : input[:,offset:end] if input is not None else None,
                                           "target_audio" : target[:,offset:end] if target is not None else None,
                                           "params" : params,
                                           "offset": offset,
                                           "frames" : num_frames})

            # add to overall file examples
            self.examples += self.file_examples
        
        # use only a fraction of the subset data if applicable
        if self.subset == "train":
            classes = set([ex['params'] for ex in self.examples])
            n_classes = len(classes) # number of unique compressor configurations
            fraction_examples = int(len(self.examples) * self.fraction)
            n_examples_per_class = int(fraction_examples / n_classes)
            n_min_total = ((self.length * n_examples_per_class * n_classes) / md.sample_rate) / 60 
            n_min_per_class = ((self.length * n_examples_per_class) / md.sample_rate) / 60 
            print(sorted(classes))
            print(f"Total Examples: {len(self.examples)}     Total classes: {n_classes}")
            print(f"Fraction examples: {fraction_examples}    Examples/class: {n_examples_per_class}")
            print(f"Training with {n_min_per_class:0.2f} min per class    Total of {n_min_total:0.2f} min")

            if n_examples_per_class <= 0: 
                raise ValueError(f"Fraction `{self.fraction}` set too low. No examples selected.")

            sampled_examples = []

            for config_class in classes: # select N examples from each class
                class_examples = [ex for ex in self.examples if ex["params"] == config_class]
                example_indices = np.random.randint(0, high=len(class_examples), size=n_examples_per_class)
                class_examples = [class_examples[idx] for idx in example_indices]
                extra_factor = int(1/self.fraction)
                sampled_examples += class_examples * extra_factor

            self.examples = sampled_examples

        self.minutes = ((self.length * len(self.examples)) / md.sample_rate) / 60 

        # we then want to get the input files
        print(f"Located {len(self.examples)} examples totaling {self.minutes:0.2f} min in the {self.subset} subset.")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        if self.preload:
            audio_idx = self.examples[idx]["idx"]
            offset = self.examples[idx]["offset"]
            input = self.examples[idx]["input_audio"]
            target = self.examples[idx]["target_audio"]
        else:
            offset = self.examples[idx]["offset"] 
            input, sr  = torchaudio.load(self.examples[idx]["input_file"], 
                                        num_frames=self.length, 
                                        frame_offset=offset, 
                                        normalize=False)
            target, sr = torchaudio.load(self.examples[idx]["target_file"], 
                                        num_frames=self.length, 
                                        frame_offset=offset, 
                                        normalize=False)
            if self.half:
                input = input.half()
                target = target.half()

        # at random with p=0.5 flip the phase 
        if np.random.rand() > 0.5:
            input *= -1
            target *= -1

        # then get the tuple of parameters
        params = torch.tensor(self.examples[idx]["params"]).unsqueeze(0)
        params[:,1] /= 100

        return input, target, params

    def load(self, filename):
        if self.use_soundfile:
            x, sr = sf.read(filename, always_2d=True)
            x = torch.tensor(x.T)
        else:
            x, sr = torchaudio.load(filename, normalize=False)
        return x, sr
    
    def load_accompanying(self, filename, filename_list):
        accompanying = get_accompanying_tracks(filename, filename_list, remove=False)
        num_channels = len(accompanying)
        md = torchaudio.info(filename)   # Assume all accompanying tracks are the same shape, etc! 
        num_frames = md.num_frames
        data = torch.empty((num_channels,num_frames))
        for c, afile in enumerate(accompanying):
            data[c], sr  = self.load(afile)
        return data, sr
        
{% endraw %} {% raw %}
class FastProAudioDataset_fastai(FastProAudioDataset):
    "For fastai's sake, have getitem pack the inputs and params together"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def __getitem__(self, idx):
        input, target, params = super().__getitem__(idx)
        return torch.cat((input,params),dim=-1), target   # pack input and params together
{% endraw %} {% raw %}
class Args(object):  # stand-in for argparse. these are all micro-tcn defaults
    model_type ='tcn'
    root_dir = str(path)
    preload = False
    sample_rate = 44100
    shuffle = True
    train_subset = 'train'
    val_subset = 'val'
    train_length = 65536
    train_fraction = 1.0
    eval_length = 131072
    batch_size = 8   # original is 32, my laptop needs smaller, esp. w/o half precision
    num_workers = 4
    precision = 32  # LEAVE AS 32 FOR NOW: HALF PRECISION (16) NOT WORKING YET -SHH
    n_params = 2
    
args = Args()

#if args.precision == 16:  torch.set_default_dtype(torch.float16)

# setup the dataloaders
train_dataset = FastProAudioDataset_fastai(args.root_dir, 
                    subset=args.train_subset, 
                    fraction=args.train_fraction,
                    half=True if args.precision == 16 else False, 
                    preload=args.preload, 
                    length=args.train_length)

train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                    shuffle=args.shuffle,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

val_dataset = FastProAudioDataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset=args.val_subset,
                    length=args.eval_length)

val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)
[(0.0, 0.0), (0.0, 5.0), (0.0, 15.0), (0.0, 20.0), (0.0, 25.0), (0.0, 30.0), (0.0, 35.0), (0.0, 40.0), (0.0, 45.0), (0.0, 55.0), (0.0, 60.0), (0.0, 65.0), (0.0, 70.0), (0.0, 75.0), (0.0, 80.0), (0.0, 85.0), (0.0, 90.0), (0.0, 95.0), (0.0, 100.0), (1.0, 0.0), (1.0, 5.0), (1.0, 15.0), (1.0, 20.0), (1.0, 25.0), (1.0, 30.0), (1.0, 35.0), (1.0, 40.0), (1.0, 45.0), (1.0, 50.0), (1.0, 55.0), (1.0, 60.0), (1.0, 65.0), (1.0, 75.0), (1.0, 80.0), (1.0, 85.0), (1.0, 90.0), (1.0, 95.0), (1.0, 100.0)]
Total Examples: 396     Total classes: 38
Fraction examples: 396    Examples/class: 10
Training with 0.25 min per class    Total of 9.41 min
Located 380 examples totaling 9.41 min in the train subset.
Located 45 examples totaling 2.23 min in the val subset.
{% endraw %}
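The printed counts can be reproduced by hand. This sketch (plain Python, with the numbers taken from the printout above) mirrors the fraction-sampling arithmetic in the dataset's train branch:

{% raw %}
```python
# Reproduce the example counts printed for the train subset.
total_examples, n_classes, fraction = 396, 38, 1.0
fraction_examples = int(total_examples * fraction)  # 396
per_class = int(fraction_examples / n_classes)      # 396 // 38 = 10
extra_factor = int(1 / fraction)                    # 1
sampled = per_class * n_classes * extra_factor
print(sampled)                                      # 380 "Located" examples

# minutes of audio: 380 patches of 65536 samples at 44.1 kHz
minutes = 65536 * sampled / 44100 / 60
print(round(minutes, 2))                            # 9.41
```
{% endraw %}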

If the user requested fp16 precision then we need to install NVIDIA apex:

{% raw %}
if False and args.precision == 16:
    %pip install -q --disable-pip-version-check --no-cache-dir git+https://github.com/NVIDIA/apex
    from apex.fp16_utils import convert_network
{% endraw %}

Define the model(s)

Christian defined a lot of models. We'll do the TCN-300; the LSTM depends on a lot of PyTorch Lightning machinery, so we'll skip it.

{% raw %}
from microtcn.tcn_bare import TCNModel as TCNModel
#from microtcn.lstm import LSTMModel # actually the LSTM depends on a lot of Lightning stuff, so we'll skip that
from microtcn.utils import center_crop, causal_crop

class TCNModel_fastai(TCNModel):
    "For fastai's sake, unpack the inputs and params"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def forward(self, x, p=None):
        if (p is None) and (self.nparams > 0):  # unpack the params if needed
            assert len(list(x.size())) == 3   # sanity check 
            x, p = x[:,:,0:-self.nparams], x[:,:,-self.nparams:]
        return super().forward(x, p=p)
{% endraw %} {% raw %}
# micro-tcn defines several different model configurations. I just chose one of them. 
train_configs = [
      {"name" : "TCN-300",
     "model_type" : "tcn",
     "nblocks" : 10,
     "dilation_growth" : 2,
     "kernel_size" : 15,
     "causal" : False,
     "train_fraction" : 1.00,
     "batch_size" : args.batch_size
    }
]

dict_args = train_configs[0]
dict_args["nparams"] = 2

model = TCNModel_fastai(**dict_args)
dtype = torch.float32
{% endraw %}
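The `_fastai` wrapper classes communicate the conditioning parameters by concatenating them onto the end of the audio axis and slicing them back off in `forward`. The indexing convention can be illustrated with plain lists standing in for the last tensor dimension (the values below are just made-up examples):

{% raw %}
```python
# Illustrate the pack/unpack convention used by the _fastai wrappers
# (lists stand in for the last dimension of the tensors).
nparams = 2
audio  = [0.1, 0.2, 0.3, 0.4]   # stand-in for one waveform's samples
params = [1.0, 0.65]            # stand-in for the two compressor params
packed = audio + params         # like torch.cat((input, params), dim=-1)
unpacked_audio  = packed[:-nparams]   # x[:, :, 0:-nparams]
unpacked_params = packed[-nparams:]   # x[:, :, -nparams:]
print(unpacked_audio, unpacked_params)
```
{% endraw %}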

Let's take a look at the model:

{% raw %}
# this summary allows one to compare the original TCNModel with the TCNModel_fastai
if type(model) == TCNModel_fastai:
    torchsummary.summary(model, [(1,args.train_length)], device="cpu")
else:
    torchsummary.summary(model, [(1,args.train_length),(1,2)], device="cpu")
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 16]              48
              ReLU-2                [-1, 1, 16]               0
            Linear-3                [-1, 1, 32]             544
              ReLU-4                [-1, 1, 32]               0
            Linear-5                [-1, 1, 32]           1,056
              ReLU-6                [-1, 1, 32]               0
            Conv1d-7            [-1, 32, 65520]             480
            Linear-8                [-1, 1, 64]           2,112
       BatchNorm1d-9            [-1, 32, 65520]               0
             FiLM-10            [-1, 32, 65520]               0
            PReLU-11            [-1, 32, 65520]              32
           Conv1d-12            [-1, 32, 65534]              32
         TCNBlock-13            [-1, 32, 65520]               0
           Conv1d-14            [-1, 32, 65492]          15,360
           Linear-15                [-1, 1, 64]           2,112
      BatchNorm1d-16            [-1, 32, 65492]               0
             FiLM-17            [-1, 32, 65492]               0
            PReLU-18            [-1, 32, 65492]              32
           Conv1d-19            [-1, 32, 65520]              32
         TCNBlock-20            [-1, 32, 65492]               0
           Conv1d-21            [-1, 32, 65436]          15,360
           Linear-22                [-1, 1, 64]           2,112
      BatchNorm1d-23            [-1, 32, 65436]               0
             FiLM-24            [-1, 32, 65436]               0
            PReLU-25            [-1, 32, 65436]              32
           Conv1d-26            [-1, 32, 65492]              32
         TCNBlock-27            [-1, 32, 65436]               0
           Conv1d-28            [-1, 32, 65324]          15,360
           Linear-29                [-1, 1, 64]           2,112
      BatchNorm1d-30            [-1, 32, 65324]               0
             FiLM-31            [-1, 32, 65324]               0
            PReLU-32            [-1, 32, 65324]              32
           Conv1d-33            [-1, 32, 65436]              32
         TCNBlock-34            [-1, 32, 65324]               0
           Conv1d-35            [-1, 32, 65100]          15,360
           Linear-36                [-1, 1, 64]           2,112
      BatchNorm1d-37            [-1, 32, 65100]               0
             FiLM-38            [-1, 32, 65100]               0
            PReLU-39            [-1, 32, 65100]              32
           Conv1d-40            [-1, 32, 65324]              32
         TCNBlock-41            [-1, 32, 65100]               0
           Conv1d-42            [-1, 32, 64652]          15,360
           Linear-43                [-1, 1, 64]           2,112
      BatchNorm1d-44            [-1, 32, 64652]               0
             FiLM-45            [-1, 32, 64652]               0
            PReLU-46            [-1, 32, 64652]              32
           Conv1d-47            [-1, 32, 65100]              32
         TCNBlock-48            [-1, 32, 64652]               0
           Conv1d-49            [-1, 32, 63756]          15,360
           Linear-50                [-1, 1, 64]           2,112
      BatchNorm1d-51            [-1, 32, 63756]               0
             FiLM-52            [-1, 32, 63756]               0
            PReLU-53            [-1, 32, 63756]              32
           Conv1d-54            [-1, 32, 64652]              32
         TCNBlock-55            [-1, 32, 63756]               0
           Conv1d-56            [-1, 32, 61964]          15,360
           Linear-57                [-1, 1, 64]           2,112
      BatchNorm1d-58            [-1, 32, 61964]               0
             FiLM-59            [-1, 32, 61964]               0
            PReLU-60            [-1, 32, 61964]              32
           Conv1d-61            [-1, 32, 63756]              32
         TCNBlock-62            [-1, 32, 61964]               0
           Conv1d-63            [-1, 32, 58380]          15,360
           Linear-64                [-1, 1, 64]           2,112
      BatchNorm1d-65            [-1, 32, 58380]               0
             FiLM-66            [-1, 32, 58380]               0
            PReLU-67            [-1, 32, 58380]              32
           Conv1d-68            [-1, 32, 61964]              32
         TCNBlock-69            [-1, 32, 58380]               0
           Conv1d-70            [-1, 32, 51212]          15,360
           Linear-71                [-1, 1, 64]           2,112
      BatchNorm1d-72            [-1, 32, 51212]               0
             FiLM-73            [-1, 32, 51212]               0
            PReLU-74            [-1, 32, 51212]              32
           Conv1d-75            [-1, 32, 58380]              32
         TCNBlock-76            [-1, 32, 51212]               0
           Conv1d-77             [-1, 1, 51212]              33
================================================================
Total params: 162,161
Trainable params: 162,161
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 922.11
Params size (MB): 0.62
Estimated Total Size (MB): 922.98
----------------------------------------------------------------
{% endraw %}
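A few of the parameter counts in this summary can be sanity-checked by hand, assuming the block convolutions have no bias (which the counts imply):

{% raw %}
```python
# Conv1d parameters = in_channels * out_channels * kernel_size (+ out_channels if bias)
first_conv = 1 * 32 * 15       # 480:   1 -> 32 channels, kernel 15, no bias
block_conv = 32 * 32 * 15      # 15360: the repeated 32 -> 32 TCNBlock convs
final_conv = 32 * 1 * 1 + 1    # 33:    the last 1x1 conv down to mono, with bias
print(first_conv, block_conv, final_conv)
```
{% endraw %}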

Getting the model into fastai form

Zach Mueller made a very helpful fastai_minima package; we'll use it and follow his instructions.

TODO: Zach says I should either use fastai or fastai_minima, not mix them like I'm about to do. But what I have below is the only thing that works right now. ;-)

{% raw %}
# I guess we could've imported these up at the top of the notebook...
from torch import optim
from fastai_minima.optimizer import OptimWrapper
#from fastai_minima.learner import Learner  # this doesn't include lr_find()
from fastai.learner import Learner
from fastai_minima.learner import DataLoaders
#from fastai_minima.callback.training_utils import CudaCallback, ProgressCallback # not sure if I need these
{% endraw %} {% raw %}
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, **kwargs))

dls = DataLoaders(train_dataloader, val_dataloader)
{% endraw %}

Checking: Let's make sure the Dataloaders are working

{% raw %}
if args.precision==16: 
    dtype = torch.float16
    model = convert_network(model, torch.float16)

model = model.to('cuda:0')
if type(model) == TCNModel_fastai:
    print("We're using Hawley's modified code")
    packed, targ = dls.one_batch()
    inp, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
else:
    print("We're using Christian's version of Dataloader and model")
    inp, targ, params = dls.one_batch()
    pred = model.forward(inp.to('cuda:0',dtype=dtype), p=params.to('cuda:0', dtype=dtype))
print(f"input  = {inp.size()}\ntarget = {targ.size()}\nparams = {params.size()}\npred   = {pred.size()}")
We're using Hawley's modified code
input  = torch.Size([8, 1, 65536])
target = torch.Size([8, 1, 65536])
params = torch.Size([8, 1, 2])
pred   = torch.Size([8, 1, 51214])
{% endraw %}
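Notice the prediction is shorter than the input. Each TCN block applies a "valid" dilated convolution, trimming (kernel_size - 1) * dilation samples, with dilations growing as dilation_growth**block per the config above. A quick sketch of that arithmetic reproduces the 51214:

{% raw %}
```python
# Why pred ends up with 51214 samples for a 65536-sample input.
def tcn_output_length(n, nblocks=10, kernel_size=15, dilation_growth=2):
    for b in range(nblocks):
        n -= (kernel_size - 1) * dilation_growth ** b  # each block trims this much
    return n

print(tcn_output_length(65536))   # 51214 (train_length)
print(tcn_output_length(131072))  # 116750 (eval_length)
```
{% endraw %}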

We can make the pred and target the same length by cropping when we compute the loss:

{% raw %}
class Crop_Loss:
    "Crop target size to match preds"
    def __init__(self, axis=-1, causal=False, reduction="mean", func=nn.L1Loss):
        store_attr()
        self.loss_func = func()
    def __call__(self, pred, targ):
        targ = causal_crop(targ, pred.shape[-1]) if self.causal else center_crop(targ, pred.shape[-1])
        #pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, f'pred.shape = {pred.shape} but targ.shape = {targ.shape}'
        return self.loss_func(pred,targ).flatten().mean() if self.reduction == "mean" else self.loss_func(pred,targ).flatten().sum()
    

# we could add a metric like MSE if we want
def crop_mse(pred, targ, causal=False): 
    targ = causal_crop(targ, pred.shape[-1]) if causal else center_crop(targ, pred.shape[-1])
    return ((pred - targ)**2).mean()
{% endraw %}
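For reference, here is a minimal 1-D sketch of the symmetric trim that (I believe) `microtcn.utils.center_crop` performs along the last axis:

{% raw %}
```python
# Crop the center `length` samples out of a longer sequence.
def center_crop_1d(x, length):
    start = (len(x) - length) // 2
    return x[start:start + length]

targ = list(range(10))          # stand-in for the last axis of a target tensor
print(center_crop_1d(targ, 6))  # [2, 3, 4, 5, 6, 7]
```
{% endraw %}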

Enable logging with WandB:

{% raw %}
wandb.login()
wandb: Currently logged in as: drscotthawley (use `wandb login --relogin` to force relogin)
True
{% endraw %}

Define the fastai Learner and callbacks

We're going to add a new custom WandBAudio callback further below, which we'll use when we call fit().

WandBAudio Callback

In order to log audio samples, let's write our own audio-logging callback for fastai:

{% raw %}
class WandBAudio(Callback):
    """Progress-like callback: log audio to WandB"""
    order = ProgressCallback.order+1
    def __init__(self, n_preds=5, sample_rate=44100):
        store_attr()

    def after_epoch(self):  
        if not self.learn.training:
            with torch.no_grad():
                preds, targs = [x.detach().cpu().numpy().copy() for x in [self.learn.pred, self.learn.y]]
            log_dict = {}
            for i in range(min(self.n_preds, preds.shape[0])): # note wandb only supports mono
                    log_dict[f"preds_{i}"] = wandb.Audio(preds[i,0,:], caption=f"preds_{i}", sample_rate=self.sample_rate)
            wandb.log(log_dict)
{% endraw %}

Learner and wandb init

{% raw %}
wandb.init(project='micro-tcn-fastai')#  no name, name=json.dumps(dict_args))

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func,
               cbs= [WandbCallback()])
{% endraw %}

Train the model

We can use the fastai learning rate finder to suggest a learning rate:

{% raw %}
learn.lr_find(end_lr=0.1) 
SuggestedLRs(valley=0.0006918309954926372)
{% endraw %}

And now we'll train using the one-cycle LR schedule, with the WandBAudio callback. (Ignore any warning messages)

{% raw %}
epochs = 20  # change to 50 for better results but a longer wait
learn.fit_one_cycle(epochs, lr_max=3e-3, cbs=WandBAudio(sample_rate=args.sample_rate))
Could not gather input dimensions
WandbCallback requires use of "SaveModelCallback" to log best model
WandbCallback was not able to prepare a DataLoader for logging prediction samples -> 
epoch train_loss valid_loss crop_mse time
0 0.143242 0.098410 0.020299 00:06
1 0.096335 0.061745 0.007963 00:05
2 0.065788 0.035349 0.003570 00:05
3 0.045120 0.027977 0.001921 00:05
4 0.034311 0.023991 0.001443 00:05
5 0.026962 0.020367 0.001035 00:06
6 0.023846 0.020088 0.000883 00:05
7 0.021708 0.015346 0.000704 00:06
8 0.019866 0.026435 0.001117 00:06
9 0.017529 0.012842 0.000533 00:05
10 0.016500 0.013006 0.000504 00:05
11 0.015390 0.011723 0.000425 00:06
12 0.014275 0.012459 0.000437 00:06
13 0.013890 0.012470 0.000408 00:05
14 0.013401 0.013570 0.000454 00:05
15 0.012933 0.011421 0.000390 00:06
16 0.012545 0.010564 0.000362 00:05
17 0.012153 0.011395 0.000392 00:05
18 0.011879 0.010478 0.000356 00:05
19 0.011740 0.010412 0.000361 00:06
{% endraw %} {% raw %}
wandb.finish() # call wandb.finish() after training or your logs may be incomplete

Waiting for W&B process to finish, PID 1852379... (success).

Run history:


crop_mse     █▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dampening_0  ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch        ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr_0         ▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0        ██▇▆▅▄▃▂▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████
nesterov_0   ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss     █▆▄▃▃▃▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss   █▇▅▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss   █▅▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
wd_0         ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:


crop_mse     0.00036
dampening_0  0
epoch        20
lr_0         0.0
mom_0        0.95
nesterov_0   False
raw_loss     0.01222
train_loss   0.01174
valid_loss   0.01041
wd_0         0
Synced 5 W&B file(s), 100 media file(s), 0 artifact file(s) and 0 other file(s)
Synced fresh-salad-56: https://wandb.ai/drscotthawley/micro-tcn-fastai/runs/9w1h46em
Find logs at: ./wandb/run-20211025_091818-9w1h46em/logs
{% endraw %} {% raw %}
learn.save('micro-tcn-fastai')
Path('models/micro-tcn-fastai.pth')
{% endraw %}

Go check out the resulting run logs, graphs, and audio samples at https://wandb.ai/drscotthawley/micro-tcn-fastai, or... lemme see if I can embed some results below:

...ok it looks like the WandB results iframe (with cool graphs & audio) is getting filtered out of the docs (by nbdev and/or jekyll), but if you open this notebook file -- e.g. click the "Open in Colab" badge at the top -- then scroll down and you'll see the report. Or just go to the WandB link posted above!

TODO: Inference / Evaluation / Analysis

Load in the testing data

{% raw %}
test_dataset = FastProAudioDataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset='test',
                    length=args.eval_length)

test_dataloader = torch.utils.data.DataLoader(test_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func, cbs=[])
learn.load('micro-tcn-fastai')
Located 9 examples totaling 0.45 min in the test subset.
<fastai.learner.Learner at 0x7fcff933b430>
{% endraw %}

^^ 9 examples? I thought there were only 3:

{% raw %}
!ls {path}/Test
input_235_.wav	input_259_.wav		       target_256_LA2A_2c__1__65.wav
input_256_.wav	target_235_LA2A_2c__0__65.wav  target_259_LA2A_2c__1__80.wav
{% endraw %}

...Ok I don't understand that yet. Moving on:
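One plausible explanation, offered as an assumption rather than something confirmed above: the dataset creates one example per non-overlapping eval_length patch, and each 10-second file at 44.1 kHz holds three such patches:

{% raw %}
```python
# 3 test files, each 10 s at 44.1 kHz = 441000 frames; eval_length = 131072
frames_per_file = 441000
patches_per_file = frames_per_file // 131072  # 3 non-overlapping patches
n_examples = patches_per_file * 3             # 3 files in Test/
minutes = n_examples * 131072 / 44100 / 60
print(n_examples, round(minutes, 2))          # 9 examples, 0.45 min
```
{% endraw %}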

Let's get some predictions from the model. Note that the length of these predictions will be greater than in training, because we specified the two lengths differently:

{% raw %}
print(args.train_length, args.eval_length)
65536 131072
{% endraw %}

Handy routine to grab some data and run it through the model to get predictions:

{% raw %}
def get_pred_batch(dataloader, crop_target=True, causal=False):
    packed, target = next(iter(dataloader))
    input, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
    if crop_target: target = causal_crop(target, pred.shape[-1]) if causal else center_crop(target, pred.shape[-1])
    input, params, target, pred = [x.detach().cpu() for x in [input, params, target, pred]]
    return input, params, target, pred
{% endraw %} {% raw %}
input, params, target, pred = get_pred_batch(test_dataloader, causal=dict_args['causal'])
i = 0  # just look at the first element
print(f"------- i = {i} ---------\n")
print(f"prediction:")
show_audio(pred[i], sample_rate)
------- i = 0 ---------

prediction:
Shape: (1, 116750), Dtype: torch.float32, Duration: 2.647392290249433 s
Max:  0.139,  Min: -0.147, Mean:  0.000, Std Dev:  0.037
{% endraw %} {% raw %}
print(f"target:")
show_audio(target[i], sample_rate)
target:
Shape: (1, 116750), Dtype: torch.float32, Duration: 2.647392290249433 s
Max:  0.215,  Min: -0.202, Mean:  0.000, Std Dev:  0.053
{% endraw %}

TODO: More. We're not finished. I'll come back and add more to this later.

Deployment / Plugins

Check out Christian's GitHub page for micro-tcn where he provides instructions and JUCE files by which to render the model as an audio plugin. Pretty sure you can only do this with the causal models, which I didn't include -- yet!